package com.jstarcraft.ai.jsat.math.optimization;

import com.jstarcraft.ai.jsat.linear.IndexValue;
import com.jstarcraft.ai.jsat.linear.Matrix;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.math.Function;
import com.jstarcraft.ai.jsat.math.FunctionVec;

/**
 * Implementation of the Broyden–Fletcher–Goldfarb–Shanno (BFGS) algorithm for
 * function minimization. For {@code n} dimensional problems it requires
 * <i>O(n<sup>2</sup>)</i> work per iteration, using only first order (gradient)
 * information to maintain an approximation of the inverse Hessian.
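 * <p>
 * A minimal usage sketch (the objective {@code f} and the starting point
 * {@code x0} are placeholders supplied by the caller):
 * <pre>{@code
 * BFGS bfgs = new BFGS();
 * Vec w = x0.clone(); // the solution is written into w
 * // passing null for the gradient falls back to a forward-difference approximation
 * bfgs.optimize(1e-4, w, x0, f, null, false);
 * }</pre>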
 * 
 * @author Edward Raff
 */
public class BFGS implements Optimizer {
    private LineSearch lineSearch;
    private int maxIterations;
    private boolean inftNormCriterion = true;

    /**
     * Creates a new BFGS optimization object that uses a maximum of 250 iterations
     * and a {@link BacktrackingArmijoLineSearch backtracking} line search.
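     * <p>
     * To override these defaults, use {@link #BFGS(int, LineSearch)}; for example
     * (the cap of 500 iterations here is an arbitrary illustration):
     * <pre>{@code
     * BFGS bfgs = new BFGS(500, new BacktrackingArmijoLineSearch());
     * }</pre>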
     */
    public BFGS() {
        this(250, new BacktrackingArmijoLineSearch());
    }

    /**
     * Creates a new BFGS optimization object
     * 
     * @param maxIterations the maximum number of iterations to allow before
     *                      stopping
     * @param lineSearch    the line search method to use on updates
     */
    public BFGS(int maxIterations, LineSearch lineSearch) {
        setMaximumIterations(maxIterations);
        setLineSearch(lineSearch);
    }

    @Override
    public void optimize(double tolerance, Vec w, Vec x0, Function f, FunctionVec fp, boolean parallel) {
        if (fp == null)
            fp = Function.forwardDifference(f);
        LineSearch search = lineSearch.clone();

        Matrix H = Matrix.eye(x0.length());// approximation of the inverse Hessian, initialized to the identity
        Vec x_prev = x0.clone();
        Vec x_cur = x0.clone();
        final double[] f_xVal = new double[1];// storage for f(x), updated by the line search

        // gradient at the current and previous points
        Vec x_grad = x0.clone();
        x_grad.zeroOut();
        Vec x_gradPrev = x_grad.clone();
        // search direction p_k and the difference vectors s_k, y_k
        Vec p_k = x_grad.clone();
        Vec s_k = x_grad.clone();
        Vec y_k = x_grad.clone();

        f_xVal[0] = f.f(x_cur);
        x_grad = fp.f(x_cur, x_grad);

        int iter = 0;
        while (gradConvgHelper(x_grad) > tolerance && iter < maxIterations) {
            iter++;
            p_k.zeroOut();
            H.multiply(x_grad, -1, p_k);// p_k = −H_k ∇f_k; (6.18)

            // Set x_k+1 = x_k + α_k p_k where α_k is computed from a line search
            x_cur.copyTo(x_prev);
            x_grad.copyTo(x_gradPrev);

            double alpha_k = search.lineSearch(1.0, x_prev, x_gradPrev, p_k, f, fp, f_xVal[0], x_gradPrev.dot(p_k), x_cur, f_xVal, x_grad, parallel);
            if (alpha_k < 1e-12 && iter > 5)// near-zero steps after the first few iterations: consider it converged
                break;

            if (!search.updatesGrad())// recompute the gradient at the new point if the line search did not
                fp.f(x_cur, x_grad, parallel);

            // Define s_k =x_k+1 −x_k and y_k = ∇f_k+1 −∇f_k;
            x_cur.copyTo(s_k);
            s_k.mutableSubtract(x_prev);

            x_grad.copyTo(y_k);
            y_k.mutableSubtract(x_gradPrev);
            // Compute H_k+1 by means of (6.17);

            double skyk = s_k.dot(y_k);
            if (skyk <= 0) {// curvature condition s_k^T y_k > 0 violated; reset H to the identity
                H.zeroOut();
                for (int i = 0; i < H.rows(); i++)
                    H.set(i, i, 1);
                continue;
            }
            if (iter == 1 && skyk > 1e-12)// before the first update, rescale H_0 to (s_k^T y_k)/(y_k^T y_k) I
                for (int i = 0; i < H.rows(); i++)
                    H.set(i, i, skyk / y_k.dot(y_k));

            /*
             * BFGS update of the inverse Hessian approximation, equation (6.17)
             * (also equation 1.5 of "A Perfect Example for The BFGS Method"):
             * 
             * H_{k+1} = H_k
             *           - (s_k y_k^T H_k + H_k y_k s_k^T) / (s_k^T y_k)
             *           + (1 + (y_k^T H_k y_k) / (s_k^T y_k)) * (s_k s_k^T) / (s_k^T y_k)
             * 
             * TODO: y_k^T H_k y_k should be just a scalar constant
             * TODO: exploit the symmetry of H_k
             */

            Vec Hkyk = H.multiply(y_k);
            Vec ykHk = y_k.multiply(H);
            double b = (1 + y_k.dot(Hkyk) / skyk) / skyk;// coefficient for the rank-one s_k s_k^T update

            // apply the three rank-one updates that form H_{k+1}
            Matrix.OuterProductUpdate(H, s_k, ykHk, -1 / skyk);
            Matrix.OuterProductUpdate(H, Hkyk, s_k, -1 / skyk);
            Matrix.OuterProductUpdate(H, s_k, s_k, b);
        }

        x_cur.copyTo(w);
    }

    /**
     * By default the infinity norm of the gradient is used to judge convergence.
     * If set to {@code false}, the 2 norm will be used instead.
     * 
     * @param inftNormCriterion {@code true} to use the infinity norm, {@code false}
     *                          to use the 2 norm
     */
    public void setInftNormCriterion(boolean inftNormCriterion) {
        this.inftNormCriterion = inftNormCriterion;
    }

    /**
     * Returns whether or not the infinity norm ({@code true}) or 2 norm
     * ({@code false}) is used to determine convergence.
     * 
     * @return {@code true} if the infinity norm is in use, {@code false} for the 2
     *         norm
     */
    public boolean isInftNormCriterion() {
        return inftNormCriterion;
    }

    /**
     * Returns the norm of the gradient used for the convergence check: the
     * infinity norm when {@link #isInftNormCriterion()} is {@code true},
     * otherwise the 2 norm.
     * 
     * @param grad the current gradient
     * @return the selected norm of the gradient
     */
    private double gradConvgHelper(Vec grad) {
        if (!inftNormCriterion)
            return grad.pNorm(2);
        double max = 0;
        for (IndexValue iv : grad)
            max = Math.max(max, Math.abs(iv.getValue()));
        return max;
    }

    @Override
    public void setMaximumIterations(int iterations) {
        if (iterations < 1)
            throw new IllegalArgumentException("Iterations must be a positive value, not " + iterations);
        this.maxIterations = iterations;
    }

    @Override
    public int getMaximumIterations() {
        return maxIterations;
    }

    /**
     * Sets the line search method used at each iteration
     * 
     * @param lineSearch the line search method used at each iteration
     */
    public void setLineSearch(LineSearch lineSearch) {
        this.lineSearch = lineSearch;
    }

    /**
     * Returns the line search method used at each iteration
     * 
     * @return the line search method used at each iteration
     */
    public LineSearch getLineSearch() {
        return lineSearch;
    }

    @Override
    public Optimizer clone() {
        return new BFGS(maxIterations, lineSearch.clone());
    }

}
